{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Encoder Customization\n",
    "\n",
    "This notebook shows how to swap the encoder of a model in RL4CO and how to write your own.\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/ai4co/rl4co/blob/main/examples/modeling/3-change-encoder.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!\n",
    "\n",
    "> Note: You may need to restart the Colab runtime after installing.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install rl4co[graph] # include torch-geometric\n",
    "\n",
    "# NOTE: to install the latest version from GitHub (may be unstable), install from source instead:\n",
    "# !pip install git+https://github.com/ai4co/rl4co.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl4co.envs import CVRPEnv\n",
    "\n",
    "from rl4co.models.zoo import AttentionModel\n",
    "from rl4co.utils.trainer import RL4COTrainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A default minimal training script\n",
    "\n",
    "Here we use the CVRP environment and the Attention Model (AM) as a minimal example of a training script. By default, the AM is initialized with a Graph Attention Encoder, but we can replace it with a different encoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n",
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n",
      "Using 16bit Automatic Mixed Precision (AMP)\n",
      "GPU available: True (cuda), used: True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Encoder: GraphAttentionEncoder\n"
     ]
    }
   ],
   "source": [
    "# Init env, model, trainer\n",
    "env = CVRPEnv(generator_params=dict(num_loc=20))\n",
    "\n",
    "model = AttentionModel(\n",
    "    env, \n",
    "    baseline='rollout',\n",
    "    train_data_size=100_000, # really small size for demo\n",
    "    val_data_size=10_000\n",
    ")\n",
    " \n",
    "trainer = RL4COTrainer(\n",
    "    max_epochs=3, # few epochs for demo\n",
    "    accelerator='gpu',\n",
    "    devices=1,\n",
    "    logger=False,\n",
    ")\n",
    "\n",
    "# By default the AM uses the Graph Attention Encoder\n",
    "print(f'Encoder: {model.policy.encoder._get_name()}')"
   ]
  },
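  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trainer above hard-codes `accelerator='gpu'`, which will typically raise an error on a machine without a supported GPU. As a small sketch (not part of the original script), the accelerator can be chosen at runtime with `torch.cuda.is_available()`; the variable name `accelerator` is only illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Pick the accelerator at runtime instead of hard-coding 'gpu'\n",
    "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'Accelerator: {accelerator}')\n",
    "\n",
    "# Sketch of how it could be passed to the trainer defined above:\n",
    "# trainer = RL4COTrainer(max_epochs=3, accelerator=accelerator, devices=1, logger=False)"
   ]
  },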
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /datasets/home/botu/Dev/rl4co/notebooks/tutorials/checkpoints exists and is not empty.\n",
      "val_file not set. Generating dataset instead\n",
      "test_file not set. Generating dataset instead\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name     | Type                 | Params\n",
      "--------------------------------------------------\n",
      "0 | env      | CVRPEnv              | 0     \n",
      "1 | policy   | AttentionModelPolicy | 694 K \n",
      "2 | baseline | WarmupBaseline       | 694 K \n",
      "--------------------------------------------------\n",
      "1.4 M     Trainable params\n",
      "0         Non-trainable params\n",
      "1.4 M     Total params\n",
      "5.553     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3db02ec8f6dc4913a26462bbc1a851e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.\n",
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51e7a0cfc0da42b990392b519816193a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0e4efe5874e34ed5baf615371789ee9e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "60032ece3b2441dcb5dd3e23ed22b116",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28e10dcb8b234ec9996504c451a6dd27",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=3` reached.\n"
     ]
    }
   ],
   "source": [
    "# Train the model\n",
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Change the Encoder\n",
    "\n",
    "RL4CO provides two graph neural network encoders: a *Graph Convolutional Network* (GCN) encoder and a *Message Passing Neural Network* (MPNN) encoder. This section shows how to use them in place of the default encoder.\n",
    "\n",
    "> Note: while we provide these examples, you can also implement your own encoder and use it in RL4CO! For instance, you may use different encoders (and decoders) to solve problems that require e.g. distance matrices as input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Before we init, we need to install the graph neural network dependencies\n",
    "# !pip install rl4co[graph]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n",
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n",
      "Using 16bit Automatic Mixed Precision (AMP)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "# Init the model with a different encoder\n",
    "from rl4co.models.nn.graph.gcn import GCNEncoder\n",
    "from rl4co.models.nn.graph.mpnn import MessagePassingEncoder\n",
    "\n",
    "gcn_encoder = GCNEncoder(\n",
    "    env_name='cvrp', \n",
    "    embed_dim=128,\n",
    "    num_nodes=20, \n",
    "    num_layers=3,\n",
    ")\n",
    "\n",
    "mpnn_encoder = MessagePassingEncoder(\n",
    "    env_name='cvrp', \n",
    "    embed_dim=128,\n",
    "    num_nodes=20, \n",
    "    num_layers=3,\n",
    ")\n",
    "\n",
    "model = AttentionModel(\n",
    "    env, \n",
    "    baseline='rollout',\n",
    "    train_data_size=100_000, # really small size for demo\n",
    "    val_data_size=10_000, \n",
    "    policy_kwargs={\n",
    "        'encoder': gcn_encoder # gcn_encoder or mpnn_encoder\n",
    "    }\n",
    ")\n",
    " \n",
    "trainer = RL4COTrainer(\n",
    "    max_epochs=3, # few epochs for demo\n",
    "    accelerator='gpu',\n",
    "    devices=1,\n",
    "    logger=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /datasets/home/botu/Dev/rl4co/notebooks/tutorials/checkpoints exists and is not empty.\n",
      "val_file not set. Generating dataset instead\n",
      "test_file not set. Generating dataset instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name     | Type                 | Params\n",
      "--------------------------------------------------\n",
      "0 | env      | CVRPEnv              | 0     \n",
      "1 | policy   | AttentionModelPolicy | 148 K \n",
      "2 | baseline | WarmupBaseline       | 148 K \n",
      "--------------------------------------------------\n",
      "297 K     Trainable params\n",
      "0         Non-trainable params\n",
      "297 K     Total params\n",
      "1.191     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81c30fe25912497bb53cfb492810c655",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.\n",
      "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f85014ec77f147e594aef18f4d564499",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c2c3cea2780c41f5b8758e31117e2193",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "af0a6ff3bd4b4cbc8d04e2de5acca2bd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d56a1f56006e41dc976851912dd3d74f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=3` reached.\n"
     ]
    }
   ],
   "source": [
    "# Train the model\n",
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create your own encoder\n",
    "\n",
    "To create a new encoder, you can follow the base class below and keep these points in mind:\n",
    "\n",
    "1. RL4CO provides an `env_init_embedding` for each environment. You can use it to compute the initial embedding of the environment state.\n",
    "2. The returned hidden features `h` and `init_h` have shape `([batch_size], num_node, hidden_size)`.\n",
    "3. RL4CO keeps its graph neural network encoders in the `rl4co/models/nn/graph` folder. You may want to put your custom encoder in the same folder. Feel free to send a PR to add your encoder to RL4CO!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import necessary packages\n",
    "import torch.nn as nn\n",
    "from torch import Tensor\n",
    "from tensordict import TensorDict\n",
    "from typing import Tuple, Union\n",
    "from rl4co.models.nn.env_embeddings import env_init_embedding\n",
    "\n",
    "\n",
    "class BaseEncoder(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        env_name: str,\n",
    "        embed_dim: int,\n",
    "        init_embedding: Union[nn.Module, None] = None,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.env_name = env_name\n",
    "\n",
    "        # Use the environment's default initial embedding unless one is provided\n",
    "        self.init_embedding = (\n",
    "            env_init_embedding(self.env_name, {\"embed_dim\": embed_dim})\n",
    "            if init_embedding is None\n",
    "            else init_embedding\n",
    "        )\n",
    "\n",
    "    def forward(\n",
    "        self, td: TensorDict, mask: Union[Tensor, None] = None\n",
    "    ) -> Tuple[Tensor, Tensor]:\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            td: Input TensorDict containing the environment state\n",
    "            mask: Mask to apply to the attention\n",
    "\n",
    "        Returns:\n",
    "            h: Latent representation of the input, shape ([batch_size], num_node, hidden_size)\n",
    "            init_h: Initial embedding of the input, same shape as h\n",
    "        \"\"\"\n",
    "        init_h = self.init_embedding(td)\n",
    "        # Placeholder: a concrete encoder computes h from init_h here\n",
    "        h = None\n",
    "        return h, init_h"
   ]
  }
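  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a concrete illustration of the interface above, here is a minimal sketch of a custom encoder written in plain PyTorch. `CoordInitEmbedding` and `MLPEncoder` are hypothetical names, not part of RL4CO. The sketch assumes node coordinates are stored under the key `locs`, and it uses a plain dict in place of a `TensorDict` so that it runs without extra dependencies. In practice you would likely keep `env_init_embedding` as in `BaseEncoder`.\n",
    "\n",
    "An encoder written this way should be usable through `policy_kwargs={'encoder': ...}` in the same way as the GCN example, provided its initial embedding matches the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import Tensor\n",
    "from typing import Dict, Tuple, Union\n",
    "\n",
    "\n",
    "class CoordInitEmbedding(nn.Module):\n",
    "    # Hypothetical stand-in for env_init_embedding: projects node coordinates\n",
    "    def __init__(self, embed_dim: int, node_dim: int = 2):\n",
    "        super().__init__()\n",
    "        self.proj = nn.Linear(node_dim, embed_dim)\n",
    "\n",
    "    def forward(self, td: Dict[str, Tensor]) -> Tensor:\n",
    "        # Assumption: coordinates live under 'locs' with shape (batch, num_node, node_dim)\n",
    "        return self.proj(td['locs'])\n",
    "\n",
    "\n",
    "class MLPEncoder(nn.Module):\n",
    "    # Hypothetical encoder: refines each node embedding independently with an MLP\n",
    "    def __init__(self, embed_dim: int, init_embedding: Union[nn.Module, None] = None):\n",
    "        super().__init__()\n",
    "        self.init_embedding = (\n",
    "            CoordInitEmbedding(embed_dim) if init_embedding is None else init_embedding\n",
    "        )\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Linear(embed_dim, embed_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(embed_dim, embed_dim),\n",
    "        )\n",
    "\n",
    "    def forward(\n",
    "        self, td: Dict[str, Tensor], mask: Union[Tensor, None] = None\n",
    "    ) -> Tuple[Tensor, Tensor]:\n",
    "        # mask is accepted for interface compatibility but unused in this sketch\n",
    "        init_h = self.init_embedding(td)  # (batch, num_node, embed_dim)\n",
    "        h = init_h + self.mlp(init_h)  # residual update, same shape as init_h\n",
    "        return h, init_h\n",
    "\n",
    "\n",
    "# Quick shape check on random data: 4 instances, 21 nodes (depot + 20 customers)\n",
    "td = {'locs': torch.rand(4, 21, 2)}\n",
    "encoder = MLPEncoder(embed_dim=128)\n",
    "h, init_h = encoder(td)\n",
    "print(h.shape, init_h.shape)"
   ]
  }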
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl4co",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
